Super Resolution¶

1. Data Preparation¶

In [ ]:
# Import Libraries
import numpy as np
import tensorflow as tf
import keras
import cv2
from keras.models import Sequential
from tensorflow.keras.utils import img_to_array
import os
from tqdm import tqdm
import re
import matplotlib.pyplot as plt
2023-05-07 15:13:31.680539: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.

1.1 load data¶

In [ ]:
# to get the files in proper order
def sorted_alphanumeric(data):
    convert = lambda text: int(text) if text.isdigit() else text.lower()
    alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)',key)]
    return sorted(data,key = alphanum_key)
# defining the size of the image
SIZE = 256
high_img = []
path = 'Raw Data/high_res'
files = os.listdir(path)
files = sorted_alphanumeric(files)
for i in tqdm(files):
    if i == '855.jpg':
        break
    else:
        img = cv2.imread(path + '/'+i,1)
        # open cv reads images in BGR format so we have to convert it to RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        #resizing image
        img = cv2.resize(img, (SIZE, SIZE))
        img = img.astype('float32') / 255.0
        high_img.append(img_to_array(img))


low_img = []
path = 'Raw Data/low_res'
files = os.listdir(path)
files = sorted_alphanumeric(files)
for i in tqdm(files):
     if i == '855.jpg':
        break
     else:
        img = cv2.imread(path + '/'+i,1)

        #resizing image
        img = cv2.resize(img, (SIZE, SIZE))
        img = img.astype('float32') / 255.0
        low_img.append(img_to_array(img))
100%|██████████| 855/855 [00:02<00:00, 402.42it/s]
100%|██████████| 855/855 [00:02<00:00, 409.31it/s]

1.2 Visualize the dataset¶

In [ ]:
for i in range(4):
    a = np.random.randint(0,855)
    plt.figure(figsize=(10,10))
    plt.subplot(1,2,1)
    plt.title('High Resolution Imge', color = 'green', fontsize = 20)
    plt.imshow(high_img[a])
    plt.axis('off')
    plt.subplot(1,2,2)
    plt.title('low Resolution Image ', color = 'black', fontsize = 20)
    plt.imshow(low_img[a])
    plt.axis('off')
In [ ]:
train_high_image = high_img[:700]
train_low_image = low_img[:700]
train_high_image = np.reshape(train_high_image,(len(train_high_image),SIZE,SIZE,3))
train_low_image = np.reshape(train_low_image,(len(train_low_image),SIZE,SIZE,3))

validation_high_image = high_img[700:830]
validation_low_image = low_img[700:830]
validation_high_image= np.reshape(validation_high_image,(len(validation_high_image),SIZE,SIZE,3))
validation_low_image = np.reshape(validation_low_image,(len(validation_low_image),SIZE,SIZE,3))


test_high_image = high_img[830:]
test_low_image = low_img[830:]
test_high_image= np.reshape(test_high_image,(len(test_high_image),SIZE,SIZE,3))
test_low_image = np.reshape(test_low_image,(len(test_low_image),SIZE,SIZE,3))

print("Shape of training images:",train_high_image.shape)
print("Shape of test images:",test_high_image.shape)
print("Shape of validation images:",validation_high_image.shape)
Shape of training images: (700, 256, 256, 3)
Shape of test images: (25, 256, 256, 3)
Shape of validation images: (130, 256, 256, 3)

2. Benchmark Model - CNN¶

In [ ]:
from keras import layers
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import MultiHeadAttention

def down(filters , kernel_size, apply_batch_normalization = True):
    downsample = tf.keras.models.Sequential()
    downsample.add(layers.Conv2D(filters,kernel_size,padding = 'same', strides = 2))
    if apply_batch_normalization:
        downsample.add(layers.BatchNormalization())
    downsample.add(keras.layers.LeakyReLU())
    return downsample

def up(filters, kernel_size, dropout = False):
    upsample = tf.keras.models.Sequential()
    upsample.add(layers.Conv2DTranspose(filters, kernel_size,padding = 'same', strides = 2))
    if dropout:
        upsample.dropout(0.2)
    upsample.add(keras.layers.LeakyReLU())
    return upsample

def model():
    inputs = layers.Input(shape= [SIZE,SIZE,3])
    d1 = down(128,(3,3),False)(inputs)
    d2 = down(128,(3,3),False)(d1)
    d3 = down(256,(3,3),True)(d2)
    d4 = down(512,(3,3),True)(d3)

    d5 = down(512,(3,3),True)(d4)
    #upsampling
    u1 = up(512,(3,3),False)(d5)
    u1 = layers.concatenate([u1,d4])
    u2 = up(256,(3,3),False)(u1)
    u2 = layers.concatenate([u2,d3])
    u3 = up(128,(3,3),False)(u2)
    u3 = layers.concatenate([u3,d2])
    u4 = up(128,(3,3),False)(u3)
    u4 = layers.concatenate([u4,d1])
    u5 = up(3,(3,3),False)(u4)
    u5 = layers.concatenate([u5,inputs])
    output = layers.Conv2D(3,(2,2),strides = 1, padding = 'same')(u5)
    return tf.keras.Model(inputs=inputs, outputs=output)

model = model()
model.summary()
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 sequential (Sequential)        (None, 128, 128, 12  3584        ['input_1[0][0]']                
                                8)                                                                
                                                                                                  
 sequential_1 (Sequential)      (None, 64, 64, 128)  147584      ['sequential[0][0]']             
                                                                                                  
 sequential_2 (Sequential)      (None, 32, 32, 256)  296192      ['sequential_1[0][0]']           
                                                                                                  
 sequential_3 (Sequential)      (None, 16, 16, 512)  1182208     ['sequential_2[0][0]']           
                                                                                                  
 sequential_4 (Sequential)      (None, 8, 8, 512)    2361856     ['sequential_3[0][0]']           
                                                                                                  
 sequential_5 (Sequential)      (None, 16, 16, 512)  2359808     ['sequential_4[0][0]']           
                                                                                                  
 concatenate (Concatenate)      (None, 16, 16, 1024  0           ['sequential_5[0][0]',           
                                )                                 'sequential_3[0][0]']           
                                                                                                  
 sequential_6 (Sequential)      (None, 32, 32, 256)  2359552     ['concatenate[0][0]']            
                                                                                                  
 concatenate_1 (Concatenate)    (None, 32, 32, 512)  0           ['sequential_6[0][0]',           
                                                                  'sequential_2[0][0]']           
                                                                                                  
 sequential_7 (Sequential)      (None, 64, 64, 128)  589952      ['concatenate_1[0][0]']          
                                                                                                  
 concatenate_2 (Concatenate)    (None, 64, 64, 256)  0           ['sequential_7[0][0]',           
                                                                  'sequential_1[0][0]']           
                                                                                                  
 sequential_8 (Sequential)      (None, 128, 128, 12  295040      ['concatenate_2[0][0]']          
                                8)                                                                
                                                                                                  
 concatenate_3 (Concatenate)    (None, 128, 128, 25  0           ['sequential_8[0][0]',           
                                6)                                'sequential[0][0]']             
                                                                                                  
 sequential_9 (Sequential)      (None, 256, 256, 3)  6915        ['concatenate_3[0][0]']          
                                                                                                  
 concatenate_4 (Concatenate)    (None, 256, 256, 6)  0           ['sequential_9[0][0]',           
                                                                  'input_1[0][0]']                
                                                                                                  
 conv2d_5 (Conv2D)              (None, 256, 256, 3)  75          ['concatenate_4[0][0]']          
                                                                                                  
==================================================================================================
Total params: 9,602,766
Trainable params: 9,600,206
Non-trainable params: 2,560
__________________________________________________________________________________________________
In [ ]:
model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.001), loss = 'mean_absolute_error',
              metrics = ['acc'])
In [ ]:
model.fit(train_low_image, train_high_image, epochs = 10, batch_size = 8,
          validation_data = (validation_low_image,validation_high_image))
Epoch 1/10
88/88 [==============================] - 3s 29ms/step - loss: 0.0260 - acc: 0.7720 - val_loss: 0.0249 - val_acc: 0.7927
Epoch 2/10
88/88 [==============================] - 2s 25ms/step - loss: 0.0239 - acc: 0.8021 - val_loss: 0.0233 - val_acc: 0.7983
Epoch 3/10
88/88 [==============================] - 2s 25ms/step - loss: 0.0225 - acc: 0.8010 - val_loss: 0.0245 - val_acc: 0.8107
Epoch 4/10
88/88 [==============================] - 2s 25ms/step - loss: 0.0226 - acc: 0.8008 - val_loss: 0.0242 - val_acc: 0.7147
Epoch 5/10
88/88 [==============================] - 2s 25ms/step - loss: 0.0220 - acc: 0.8127 - val_loss: 0.0216 - val_acc: 0.8674
Epoch 6/10
88/88 [==============================] - 2s 25ms/step - loss: 0.0210 - acc: 0.8123 - val_loss: 0.0202 - val_acc: 0.8390
Epoch 7/10
88/88 [==============================] - 2s 25ms/step - loss: 0.0211 - acc: 0.8179 - val_loss: 0.0212 - val_acc: 0.8720
Epoch 8/10
88/88 [==============================] - 2s 25ms/step - loss: 0.0199 - acc: 0.8275 - val_loss: 0.0199 - val_acc: 0.8769
Epoch 9/10
88/88 [==============================] - 2s 25ms/step - loss: 0.0199 - acc: 0.8299 - val_loss: 0.0192 - val_acc: 0.8069
Epoch 10/10
88/88 [==============================] - 2s 25ms/step - loss: 0.0191 - acc: 0.8316 - val_loss: 0.0204 - val_acc: 0.8225
Out[ ]:
<keras.callbacks.History at 0x7fd4d442cbe0>

3. Model Improvement - CNN + Attention¶

3.1 Attention mechanism¶

In [ ]:
from tensorflow.keras.layers import MultiHeadAttention
def model2():
    inputs = layers.Input(shape= [SIZE,SIZE,3])
    d1 = down(128,(3,3),False)(inputs)
    d2 = down(128,(3,3),False)(d1)
    d3 = down(256,(3,3),True)(d2)
    d4 = down(512,(3,3),True)(d3)
    d5 = down(512,(3,3),True)(d4)

    # Add attention layer
    attention = MultiHeadAttention(num_heads=2, key_dim=2)
    attention_output = attention(d5, d4)

    #upsampling
    u1 = up(512,(3,3),False)(attention_output)
    u1 = layers.concatenate([u1,d4])
    u2 = up(256,(3,3),False)(u1)
    u2 = layers.concatenate([u2,d3])
    u3 = up(128,(3,3),False)(u2)
    u3 = layers.concatenate([u3,d2])
    u4 = up(128,(3,3),False)(u3)
    u4 = layers.concatenate([u4,d1])
    u5 = up(3,(3,3),False)(u4)
    u5 = layers.concatenate([u5,inputs])
    output = layers.Conv2D(3,(2,2),strides = 1, padding = 'same')(u5)
    return tf.keras.Model(inputs=inputs, outputs=output)

model2 = model2()
model2.summary()
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_2 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 sequential_10 (Sequential)     (None, 128, 128, 12  3584        ['input_2[0][0]']                
                                8)                                                                
                                                                                                  
 sequential_11 (Sequential)     (None, 64, 64, 128)  147584      ['sequential_10[0][0]']          
                                                                                                  
 sequential_12 (Sequential)     (None, 32, 32, 256)  296192      ['sequential_11[0][0]']          
                                                                                                  
 sequential_13 (Sequential)     (None, 16, 16, 512)  1182208     ['sequential_12[0][0]']          
                                                                                                  
 sequential_14 (Sequential)     (None, 8, 8, 512)    2361856     ['sequential_13[0][0]']          
                                                                                                  
 multi_head_attention (MultiHea  (None, 8, 8, 512)   8716        ['sequential_14[0][0]',          
 dAttention)                                                      'sequential_13[0][0]']          
                                                                                                  
 sequential_15 (Sequential)     (None, 16, 16, 512)  2359808     ['multi_head_attention[0][0]']   
                                                                                                  
 concatenate_5 (Concatenate)    (None, 16, 16, 1024  0           ['sequential_15[0][0]',          
                                )                                 'sequential_13[0][0]']          
                                                                                                  
 sequential_16 (Sequential)     (None, 32, 32, 256)  2359552     ['concatenate_5[0][0]']          
                                                                                                  
 concatenate_6 (Concatenate)    (None, 32, 32, 512)  0           ['sequential_16[0][0]',          
                                                                  'sequential_12[0][0]']          
                                                                                                  
 sequential_17 (Sequential)     (None, 64, 64, 128)  589952      ['concatenate_6[0][0]']          
                                                                                                  
 concatenate_7 (Concatenate)    (None, 64, 64, 256)  0           ['sequential_17[0][0]',          
                                                                  'sequential_11[0][0]']          
                                                                                                  
 sequential_18 (Sequential)     (None, 128, 128, 12  295040      ['concatenate_7[0][0]']          
                                8)                                                                
                                                                                                  
 concatenate_8 (Concatenate)    (None, 128, 128, 25  0           ['sequential_18[0][0]',          
                                6)                                'sequential_10[0][0]']          
                                                                                                  
 sequential_19 (Sequential)     (None, 256, 256, 3)  6915        ['concatenate_8[0][0]']          
                                                                                                  
 concatenate_9 (Concatenate)    (None, 256, 256, 6)  0           ['sequential_19[0][0]',          
                                                                  'input_2[0][0]']                
                                                                                                  
 conv2d_11 (Conv2D)             (None, 256, 256, 3)  75          ['concatenate_9[0][0]']          
                                                                                                  
==================================================================================================
Total params: 9,611,482
Trainable params: 9,608,922
Non-trainable params: 2,560
__________________________________________________________________________________________________
In [ ]:
model2.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.001), loss = 'mean_absolute_error',
              metrics = ['acc'])
In [ ]:
model2.fit(train_low_image, train_high_image, epochs = 15, batch_size = 8,
          validation_data = (validation_low_image,validation_high_image))
Epoch 1/15
88/88 [==============================] - 2s 28ms/step - loss: 0.0235 - acc: 0.8051 - val_loss: 0.0249 - val_acc: 0.7642
Epoch 2/15
88/88 [==============================] - 2s 27ms/step - loss: 0.0250 - acc: 0.7927 - val_loss: 0.0252 - val_acc: 0.8444
Epoch 3/15
88/88 [==============================] - 2s 27ms/step - loss: 0.0233 - acc: 0.7982 - val_loss: 0.0223 - val_acc: 0.8175
Epoch 4/15
88/88 [==============================] - 2s 28ms/step - loss: 0.0229 - acc: 0.8059 - val_loss: 0.0226 - val_acc: 0.7857
Epoch 5/15
88/88 [==============================] - 2s 28ms/step - loss: 0.0219 - acc: 0.8123 - val_loss: 0.0231 - val_acc: 0.7880
Epoch 6/15
88/88 [==============================] - 2s 27ms/step - loss: 0.0223 - acc: 0.8143 - val_loss: 0.0238 - val_acc: 0.7988
Epoch 7/15
88/88 [==============================] - 2s 27ms/step - loss: 0.0220 - acc: 0.8074 - val_loss: 0.0230 - val_acc: 0.7915
Epoch 8/15
88/88 [==============================] - 2s 27ms/step - loss: 0.0226 - acc: 0.8134 - val_loss: 0.0219 - val_acc: 0.8334
Epoch 9/15
88/88 [==============================] - 2s 27ms/step - loss: 0.0210 - acc: 0.8119 - val_loss: 0.0213 - val_acc: 0.8234
Epoch 10/15
88/88 [==============================] - 2s 27ms/step - loss: 0.0207 - acc: 0.8208 - val_loss: 0.0232 - val_acc: 0.7757
Epoch 11/15
88/88 [==============================] - 2s 28ms/step - loss: 0.0216 - acc: 0.8111 - val_loss: 0.0224 - val_acc: 0.8518
Epoch 12/15
88/88 [==============================] - 3s 29ms/step - loss: 0.0217 - acc: 0.8191 - val_loss: 0.0242 - val_acc: 0.8342
Epoch 13/15
88/88 [==============================] - 2s 27ms/step - loss: 0.0207 - acc: 0.8162 - val_loss: 0.0196 - val_acc: 0.8177
Epoch 14/15
88/88 [==============================] - 2s 28ms/step - loss: 0.0205 - acc: 0.8229 - val_loss: 0.0195 - val_acc: 0.8446
Epoch 15/15
88/88 [==============================] - 2s 28ms/step - loss: 0.0196 - acc: 0.8325 - val_loss: 0.0191 - val_acc: 0.8200
Out[ ]:
<keras.callbacks.History at 0x7fd4d4094c10>

3.2 Prediction results¶

In [ ]:
def plot_images(high,low,predicted):
    plt.figure(figsize=(15,15))
    plt.subplot(1,3,1)
    plt.title('High Resolution Image', color = 'green', fontsize = 20)
    plt.imshow(high)
    plt.subplot(1,3,2)
    plt.title('Low Resolution Image', color = 'black', fontsize = 20)
    plt.imshow(low)
    plt.subplot(1,3,3)
    plt.title('Predicted Image', color = 'Red', fontsize = 20)
    plt.imshow(predicted)

    plt.show()

for i in range(1,25):

    predicted = np.clip(model2.predict(test_low_image[i].reshape(1,SIZE, SIZE,3)),0.0,1.0).reshape(SIZE, SIZE,3)
    plot_images(test_high_image[i],test_low_image[i],predicted)
1/1 [==============================] - 0s 48ms/step
1/1 [==============================] - 0s 49ms/step
1/1 [==============================] - 0s 48ms/step
1/1 [==============================] - 0s 48ms/step
1/1 [==============================] - 0s 52ms/step
1/1 [==============================] - 0s 48ms/step
1/1 [==============================] - 0s 60ms/step
1/1 [==============================] - 0s 50ms/step
1/1 [==============================] - 0s 54ms/step
1/1 [==============================] - 0s 48ms/step
1/1 [==============================] - 0s 48ms/step
1/1 [==============================] - 0s 47ms/step
1/1 [==============================] - 0s 55ms/step
1/1 [==============================] - 0s 48ms/step
1/1 [==============================] - 0s 53ms/step
1/1 [==============================] - 0s 48ms/step
1/1 [==============================] - 0s 49ms/step
1/1 [==============================] - 0s 60ms/step
1/1 [==============================] - 0s 48ms/step
1/1 [==============================] - 0s 60ms/step
1/1 [==============================] - 0s 47ms/step
1/1 [==============================] - 0s 51ms/step
1/1 [==============================] - 0s 53ms/step
1/1 [==============================] - 0s 50ms/step